#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   datasets.py
@Time    :   2024/01/05 11:19:36
@Author  :   Robert
@Version :   1.0
@Contact :   robert.wu@tomra.com
@Desc    :   Prepare the dataset: unzip the image archive and generate the label, train and validation list files
'''

import os
import zipfile
import random
import sys
 
def unzip_file(zip_path, extract_dir):
    """Extract every member of the zip archive into a directory.

    Args:
        zip_path: Path of the zip archive to read.
        extract_dir: Directory the archive members are extracted into.
    """
    archive = zipfile.ZipFile(zip_path, mode='r')
    try:
        archive.extractall(path=extract_dir)
    finally:
        # Always release the archive handle, even if extraction fails.
        archive.close()
 

def _collect_samples(dataset_dir):
    """Scan the dataset directory and build the class list and sample lines.

    Every sub-directory of ``dataset_dir`` is treated as one class; entries
    that are not directories are skipped so stray files do not break the
    scan. Classes are numbered in sorted-name order because ``os.listdir``
    returns entries in arbitrary order, which would otherwise make the
    label-to-index mapping differ between machines.

    Args:
        dataset_dir: Directory containing one sub-directory per class.

    Returns:
        A ``(labels, samples)`` tuple. ``labels`` is the sorted list of class
        directory names; ``samples`` is a list of newline-terminated lines of
        the form ``images/<label>/<image> <class index>``.
    """
    labels = sorted(
        name for name in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, name))
    )
    samples = []
    for index, label in enumerate(labels):
        for image in sorted(os.listdir(os.path.join(dataset_dir, label))):
            samples.append('images/' + label + '/' + image + ' ' + str(index) + '\n')
    return labels, samples


def main(data='./data/', dataset='images.zip', dataset_dir='./data/images/',
         label_file='./data/label_list.txt', train='./data/train_list.txt',
         valid='./data/valid_list.txt', train_ratio=0.8):
    """Unzip the dataset if needed and write the label/train/valid lists.

    Does nothing if ``label_file`` already exists.

    Args:
        data: Directory holding the archive; also the extraction target.
        dataset: File name of the zip archive inside ``data``.
        dataset_dir: Directory with one sub-directory of images per class.
        label_file: Output file with one ``<index> <label>`` line per class.
        train: Output file receiving the training split.
        valid: Output file receiving the validation split.
        train_ratio: Fraction of the shuffled samples placed in ``train``.
    """
    if not os.path.exists(dataset_dir):
        unzip_file(data + dataset, data)

    # The label file doubles as the "lists already generated" marker.
    if os.path.exists(label_file):
        return

    # Collect everything before opening any output file, so a failure while
    # scanning cannot leave a half-written label file behind.
    labels, samples = _collect_samples(dataset_dir)
    random.shuffle(samples)
    train_size = int(len(samples) * train_ratio)

    with open(train, 'w') as f:
        f.writelines(samples[:train_size])
    with open(valid, 'w') as f:
        f.writelines(samples[train_size:])
    # Written last: its existence makes later runs exit early, so it must
    # only appear once the train/valid lists are complete.
    with open(label_file, 'w') as f:
        f.writelines(
            '{} {}\n'.format(index, label) for index, label in enumerate(labels)
        )


if __name__ == '__main__':
    main()
        
